from typing import Optional
from torch.utils.data.dataloader import DataLoader

from dataset.cyclone import CycloneDataset, CycloneSample
from dataset.cyclone_diff import CycloneDiffusionDataset, CycloneDiffSample
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist


def check_partial_holdouts(dataset_cfg):
    """Ensure every partial-holdout entry names a training trajectory.

    Partial holdouts withhold only part of a trajectory, so the trajectory
    itself must also be listed in ``dataset_cfg.training_trajectories``.

    Args:
        dataset_cfg: Dataset config exposing ``partial_holdouts`` (an iterable
            of entries with a ``trajectory`` attribute) and
            ``training_trajectories``.

    Returns:
        None.

    Raises:
        ValueError: For the first entry whose trajectory is not a training
            trajectory.
    """
    for holdout in dataset_cfg.partial_holdouts:
        name = holdout.trajectory
        if name in dataset_cfg.training_trajectories:
            continue
        raise ValueError(
            f"Trajectory '{name}' in partial_holdouts is not in training_trajectories."
        )


def _shared_dataset_kwargs(cfg, conditions):
    """Return the ``CycloneDiffusionDataset`` kwargs common to every split.

    A fresh dict (with fresh ``input_fields`` / ``conditions`` lists) is built
    on each call so the datasets never share mutable arguments.
    """
    data_cfg = cfg.dataset
    return {
        "active_keys": data_cfg.active_keys,
        "input_fields": ["df", "phi", "flux"],  # TODO figure out how to deal with eval
        "path": data_cfg.path,
        "random_seed": cfg.seed,
        "normalization": data_cfg.normalization,
        "normalization_scope": data_cfg.normalization_scope,
        "spatial_ifft": data_cfg.spatial_ifft,
        "bundle_seq_length": cfg.autoencoder.bundle_seq_length,
        "subsample": data_cfg.subsample,
        "log_transform": data_cfg.log_transform,
        "split_into_bands": data_cfg.split_into_bands,
        "minmax_beta1": data_cfg.minmax_beta1,
        "minmax_beta2": data_cfg.minmax_beta2,
        "offset": data_cfg.offset,
        "separate_zf": data_cfg.separate_zf,
        "num_workers": data_cfg.num_workers,
        "real_potens": data_cfg.real_potens,
        "conditions": list(conditions),
        "stage": cfg.stage,
    }


def _make_loader(dataset, batch_size, training_cfg, *, shuffle, use_ddp):
    """Build a ``DataLoader`` for ``dataset`` using the training worker settings.

    Under DDP, a ``DistributedSampler`` shards the data and owns shuffling,
    because ``DataLoader`` rejects ``shuffle=True`` together with a sampler.
    """
    num_workers = training_cfg.num_workers
    worker_kwargs = {}
    if num_workers > 0:
        # DataLoader raises ValueError for these options when num_workers == 0.
        worker_kwargs = {
            "persistent_workers": True,
            "prefetch_factor": max(num_workers // 2, 1),
        }
    sampler = DistributedSampler(dataset, shuffle=shuffle) if use_ddp else None
    return DataLoader(
        dataset,
        batch_size,
        num_workers=num_workers,
        shuffle=shuffle and sampler is None,
        collate_fn=dataset.collate,
        pin_memory=training_cfg.pin_memory,
        sampler=sampler,
        **worker_kwargs,
    )


def get_data(cfg, rank: Optional[int] = None):
    """Build the cyclone diffusion datasets and their data loaders.

    Args:
        cfg: Run config. This function reads ``cfg.dataset``,
            ``cfg.autoencoder`` (``conditioning``, ``bundle_seq_length``),
            ``cfg.training`` (``batch_size``, ``num_workers``, ``pin_memory``),
            ``cfg.validation.batch_size``, ``cfg.seed`` and ``cfg.stage``.
        rank: Process rank. It only decides whether summaries are printed
            (``None`` or ``0`` print).

    Returns:
        ``(datasets, loaders, augmentations)``.
        ``datasets`` is ``(trainset, holdout_trajectories_valset)``.
        When ``cfg.dataset.partial_holdouts`` is set, ``holdout_samples_valset``
        is appended as a third element. ``loaders`` mirrors ``datasets``.
        ``augmentations`` is currently always ``None``.

    Raises:
        ValueError: If ``cfg.dataset.name`` is not ``"cyclone"``, or (from
            ``check_partial_holdouts``) if a partial holdout names a
            trajectory that is not used for training.

    NOTE: under DDP the training loader uses a shuffling ``DistributedSampler``.
    Per the PyTorch docs, callers should call ``sampler.set_epoch(epoch)`` each
    epoch to get a new ordering.
    """
    data_cfg = cfg.dataset
    training_cfg = cfg.training
    validation_cfg = cfg.validation

    # Raise instead of assert: an assert is stripped under ``python -O``.
    if data_cfg.name != "cyclone":
        raise ValueError(
            f"Unsupported dataset '{data_cfg.name}'; only 'cyclone' is available."
        )

    augmentations = None
    # TODO(diff) augmentations? and set them up in the dataset directly

    # ``is_initialized`` is missing on torch builds without distributed support.
    use_ddp = dist.is_available() and dist.is_initialized()
    is_main_process = rank is None or rank == 0

    partial_holdouts = {}
    if data_cfg.partial_holdouts:
        check_partial_holdouts(data_cfg)
        partial_holdouts = {
            entry.trajectory: entry.last_n for entry in data_cfg.partial_holdouts
        }

    ae_cond = list(cfg.autoencoder.conditioning)
    # TODO(diff) make cond loading the same across both diff and normal datasets

    trainset = CycloneDiffusionDataset(
        split="train",
        trajectories=data_cfg.training_trajectories,
        partial_holdouts=partial_holdouts,
        cond_filters=data_cfg.training_cond_filters,
        **_shared_dataset_kwargs(cfg, ae_cond),
    )
    # Validation sets reuse the training normalization statistics.
    holdout_trajectories_valset = CycloneDiffusionDataset(
        split="val",
        normalization_stats=trainset.norm_stats,
        trajectories=data_cfg.validation_trajectories,
        cond_filters=data_cfg.eval_cond_filters,
        **_shared_dataset_kwargs(cfg, ae_cond),
    )

    trainloader = _make_loader(
        trainset, training_cfg.batch_size, training_cfg, shuffle=True, use_ddp=use_ddp
    )
    holdout_trajectories_valloader = _make_loader(
        holdout_trajectories_valset,
        validation_cfg.batch_size,
        training_cfg,
        shuffle=False,
        use_ddp=use_ddp,
    )

    if partial_holdouts:
        # The held-out tail samples of the training trajectories.
        holdout_samples_valset = CycloneDiffusionDataset(
            split="val",
            normalization_stats=trainset.norm_stats,
            trajectories=data_cfg.training_trajectories,
            partial_holdouts=partial_holdouts,
            cond_filters=data_cfg.eval_cond_filters,
            **_shared_dataset_kwargs(cfg, ae_cond),
        )
        holdout_samples_valloader = _make_loader(
            holdout_samples_valset,
            validation_cfg.batch_size,
            training_cfg,
            shuffle=False,
            use_ddp=use_ddp,
        )

    if is_main_process:
        ratio = len(holdout_trajectories_valset) / len(trainset)
        print(
            f"Using {len(trainset.files)} sims for training.\n"
            f"Using {len(holdout_trajectories_valset.files)} sims for validation.\n"
            f"Train: {len(trainset)} samples\n"
            f"Holdout trajectories: {len(holdout_trajectories_valset)} samples\n"
            f"Validation ratio: {ratio:.2f}\n"
        )

    if not partial_holdouts:
        return (
            (trainset, holdout_trajectories_valset),
            (trainloader, holdout_trajectories_valloader),
            augmentations,
        )

    if is_main_process:
        print(f"Holdout samples Val: {len(holdout_samples_valset)} samples")
    return (
        (trainset, holdout_trajectories_valset, holdout_samples_valset),
        (
            trainloader,
            holdout_trajectories_valloader,
            holdout_samples_valloader,
        ),
        augmentations,
    )


# Public names for ``from <module> import *``.
__all__ = [
    # Loader factory defined in this module.
    "get_data",
    # Re-exported from dataset.cyclone.
    "CycloneDataset",
    "CycloneSample",
    # Re-exported from dataset.cyclone_diff.
    "CycloneDiffusionDataset",
    "CycloneDiffSample",
]
